<template>
  <div>
    <div style="display: flex; flex-direction: row; padding: 10px" id="vueapp">
      <img src="" alt="" class="test1" style="width: 200px" />
      <!-- Left card: free-hand drawing area plus the 28x28 preview that is read as model input -->
      <div class="card">
        <div class="card-header">此处写数字</div>
        <div class="card-body">
          <!--
            mouseleave reuses the mouseup handler so that a stroke is finished
            (and the preview refreshed) when the pointer leaves the canvas while
            the button is still held; otherwise drawing would resume on re-entry.
          -->
          <canvas
            ref="drawCanvas"
            width="200"
            height="200"
            @mousedown="canvasMouseDownHandler"
            @mousemove="canvasMouseMoveHandler"
            @mouseup="canvasMouseUpHandler"
            @mouseleave="canvasMouseUpHandler"
            style="border-style: dashed; display: block"
          ></canvas>
          <div style="text-align: center">
            <button
              type="button"
              class="btn btn-primary"
              style="margin-top: 10px"
              @click="btnClearCanvasClickedHandler"
            >
              清空
            </button>
          </div>
        </div>
        <div class="card-header">图像数据预览</div>
        <div class="card-body" style="text-align: center; background-color: black">
          <canvas
            width="28"
            height="28"
            style="border-style: solid; border-color: white"
            ref="previewCanvas"
            class="test"
          ></canvas>
        </div>
      </div>
      <!-- Right card: training and prediction controls -->
      <div class="card" style="margin-left: 10px">
        <div class="card-header">训练</div>
        <div class="card-body">
          <!-- Wrapping the input in a label gives it an accessible name without needing an id -->
          <label>
            关联数字：
            <input type="number" min="0" max="9" step="1" v-model.number="targetNum" />
          </label>
          <button type="button" class="btn btn-primary" @click="btnTrainClickedHandler">
            训练
          </button>

          <div aria-live="polite">
            <!-- trainStatus is rendered as HTML; it must only ever hold markup built by this component -->
            <div v-html="trainStatus"></div>
          </div>
        </div>
        <div class="card-header">识别</div>
        <div class="card-body">
          <button type="button" class="btn btn-primary" @click="btnPredictClickedHandler">
            预测
          </button>
          <div role="status">{{ result }}</div>
        </div>
      </div>
    </div>
  </div>
</template>

<script>
// Object detection model (COCO-SSD)
import * as cocossd from "@tensorflow-models/coco-ssd";
// Question answering model (import currently disabled)
// import * as mobilenet from "@tensorflow-models/qna";

// import "https://unpkg.com/@tensorflow/tfjs"
import * as tf from "@tensorflow/tfjs";
export default {
  data() {
    return {
      targetNum: 0,
      trainStatus: "",
      result: "",
    };
  },

  mounted() {
    //
    let c2d = (this.drawCanvasContext2d = this.$refs.drawCanvas.getContext("2d"));
    c2d.lineWidth = 20;
    c2d.lineCap = "round";
    c2d.lineJoin = "round";

    this.previewCanvasContext2d = this.$refs.previewCanvas.getContext("2d");

    this.loadOrCreateModel();
  },

  methods: {
    // Step 1 (first thing done on mount): load or create the model
    async loadOrCreateModel() {
      try {
        this.model = await tf.loadLayersModel("localstorage://mymodel");
      } catch (e) {
        console.warn("Can not load model from LocalStorage, so we create a new model");

        this.model = tf.sequential({
          layers: [
            tf.layers.inputLayer({ inputShape: [784] }),
            tf.layers.dense({ units: 10 }),
            tf.layers.softmax(),
          ],
        });
      }
      
      this.model.compile({
        optimizer: "sgd",
        loss: "categoricalCrossentropy",
        metrics: ["accuracy"],
      });
    },

    getImageData() {
      let imageData = this.previewCanvasContext2d.getImageData(0, 0, 28, 28);
      // console.log(imageData,"imageData")

      let pixelData = [];

      let color;
      for (let i = 0; i < imageData.data.length; i += 4) {
        color = (imageData.data[i] + imageData.data[i + 1] + imageData.data[i + 2]) / 3;
        pixelData.push(Math.round((255 - color) / 255));
      }

      //blob允许我们可以通过js直接操作二进制数据，通过下面注释的这一段，我们能实现预测的时候进行下载
      // document.querySelector('.test').toBlob(function(blob) {
      //   var a = document.createElement("a");
      //   var body = document.getElementsByTagName("body");
      //   document.body.appendChild(a);
      //   a.download = "img" + ".jpg";
      //   a.href = window.URL.createObjectURL(blob);

      //   a.click();
      //   body.removeChild("a");
      // });

      return pixelData;
    },

    // Step 2: train the model on a single sample
    async btnTrainClickedHandler(e) {
      let data = this.getImageData();
  

      //目标数据处理：相当于将多个数值联合放在一起作为多个相同类型的向量
      let targetTensor = tf.oneHot(parseInt(this.targetNum), 10);

      let self = this;
      //一次训练一个数据
      console.log("Start training");
      await this.model.fit(tf.tensor([data]), tf.tensor([targetTensor.arraySync()]), {
        epochs: 30,
        callbacks: {
          onEpochEnd(epoch, logs) {
            console.log(epoch, logs);
            self.trainStatus = `<div>Step: ${epoch}</div><div>Loss: ${logs.loss}</div>`;
          },
        },
      });
      self.trainStatus = `<div style="color: green;">训练完成</div>`;
      console.log("Completed");

      await this.model.save("localstorage://mymodel");
    },

    
    async btnPredictClickedHandler(e) {
      let data = this.getImageData();

      let predictions = await this.model.predict(tf.tensor([data]));
      this.result = predictions.argMax(1).arraySync()[0];
    },

    // Handlers for the handwriting canvas
    canvasMouseDownHandler(e) {
      this.drawing = true;
      this.drawCanvasContext2d.beginPath();
      this.drawCanvasContext2d.moveTo(e.offsetX, e.offsetY);
    },

    canvasMouseMoveHandler(e) {
      //this.drawing是点击，不然的话会沿着鼠标移动的曲线进行绘图
      if (this.drawing) {
        this.drawCanvasContext2d.lineTo(e.offsetX, e.offsetY);
        this.drawCanvasContext2d.stroke();
      }
    },

    canvasMouseUpHandler(e) {
      this.drawing = false;

      this.previewCanvasContext2d.fillStyle = "white";
      this.previewCanvasContext2d.fillRect(0, 0, 28, 28);
      this.previewCanvasContext2d.drawImage(this.$refs.drawCanvas, 0, 0, 28, 28);
    },

    btnClearCanvasClickedHandler(e) {
      this.drawCanvasContext2d.clearRect(
        0,
        0,
        this.$refs.drawCanvas.width,
        this.$refs.drawCanvas.height
      );
    },
  },
};
</script>

<style lang="scss" scoped></style>
